package srtp

import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"errors"

	"github.com/adalkiran/webrtc-nuts-and-bolts/src/rtp"
)

// https://github.com/pion/srtp/blob/e338637eb5c459e0e43daf9c88cf28dd441eeb7c/context.go#L9
const (
	// Key-derivation labels (RFC 3711 section 4.3). Each label selects which
	// session value aesCmKeyDerivation produces from the master key/salt.
	labelSRTPEncryption        = 0 // SRTP session encryption key
	labelSRTPAuthenticationTag = 1 // SRTP session authentication key
	labelSRTPSalt              = 2 // SRTP session salt

	labelSRTCPEncryption        = 3 // SRTCP session encryption key
	labelSRTCPAuthenticationTag = 4 // SRTCP session authentication key
	labelSRTCPSalt              = 5 // SRTCP session salt

	// Half and full size of the 16-bit RTP sequence-number space.
	// NOTE(review): presumably used for rollover-counter estimation elsewhere.
	seqNumMedian = 0x8000
	seqNumMax    = 0x10000
)

// GCM holds the AES-GCM state used to unprotect SRTP/SRTCP traffic: one AEAD
// instance and one session salt for each of SRTP and SRTCP, all derived from
// the same master key and master salt by NewGCM.
type GCM struct {
	srtpGCM  cipher.AEAD // keyed with the SRTP session encryption key
	srtcpGCM cipher.AEAD // keyed with the SRTCP session encryption key

	srtpSalt  []byte // SRTP session salt, XORed into every SRTP packet IV
	srtcpSalt []byte // SRTCP session salt
}

// NewGCM creates an SRTP AES-GCM cipher from a master key and master salt.
//
// Separate session encryption keys and session salts are derived for SRTP and
// SRTCP using the AES-CM key derivation function. The session keys are the
// same length as masterKey, so masterKey must be a valid AES key length. The
// session salts are the same length as masterSalt. masterSalt must be at least
// 12 bytes because the session salt is XORed into the 12-byte GCM IV.
func NewGCM(masterKey, masterSalt []byte) (*GCM, error) {
	// A shorter salt would make the IV construction index past the end of
	// the derived session salt and panic at decrypt time; fail early instead.
	const gcmIVLen = 12
	if len(masterSalt) < gcmIVLen {
		return nil, errors.New("master salt is shorter than the 12-byte GCM IV")
	}

	srtpGCM, err := newSessionAEAD(labelSRTPEncryption, masterKey, masterSalt)
	if err != nil {
		return nil, err
	}

	srtcpGCM, err := newSessionAEAD(labelSRTCPEncryption, masterKey, masterSalt)
	if err != nil {
		return nil, err
	}

	srtpSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt))
	if err != nil {
		return nil, err
	}

	srtcpSalt, err := aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt))
	if err != nil {
		return nil, err
	}

	return &GCM{
		srtpGCM:   srtpGCM,
		srtpSalt:  srtpSalt,
		srtcpGCM:  srtcpGCM,
		srtcpSalt: srtcpSalt,
	}, nil
}

// newSessionAEAD derives the session encryption key selected by label. The key
// is the same length as masterKey. The function returns an AES-GCM AEAD keyed
// with it.
func newSessionAEAD(label byte, masterKey, masterSalt []byte) (cipher.AEAD, error) {
	sessionKey, err := aesCmKeyDerivation(label, masterKey, masterSalt, 0, len(masterKey))
	if err != nil {
		return nil, err
	}
	block, err := aes.NewCipher(sessionKey)
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// rtpInitializationVector returns the 12-byte AES-GCM IV for an SRTP packet.
// The layout matches RFC 7714 section 8.1:
//
//	bytes 0-1   zero
//	bytes 2-5   SSRC            (big-endian)
//	bytes 6-9   rollover count  (big-endian)
//	bytes 10-11 sequence number (big-endian)
//
// The whole 12-byte value is then XORed with the SRTP session salt. The salt
// must be at least 12 bytes long.
func (g *GCM) rtpInitializationVector(header *rtp.Header, roc uint32) []byte {
	var buf [12]byte
	binary.BigEndian.PutUint32(buf[2:6], header.SSRC)
	binary.BigEndian.PutUint32(buf[6:10], roc)
	binary.BigEndian.PutUint16(buf[10:12], header.SequenceNumber)

	salt := g.srtpSalt
	for idx := 0; idx < len(buf); idx++ {
		buf[idx] ^= salt[idx]
	}
	return buf[:]
}

// Decrypt authenticates and decrypts an AES-GCM protected SRTP packet.
//
// packet.RawData must hold the whole protected packet:
//   - the RTP header, whose first packet.HeaderSize bytes are used as the
//     associated data;
//   - the encrypted payload;
//   - the trailing GCM authentication tag.
//
// roc is the rollover counter used to build the IV.
//
// On success it returns a newly allocated slice containing the original header
// followed by the decrypted payload, with the tag stripped. packet.RawData is
// not modified. An error is returned if the packet is too short to contain the
// header and the tag, or if authentication fails.
func (g *GCM) Decrypt(packet *rtp.Packet, roc uint32) ([]byte, error) {
	ciphertext := packet.RawData
	headerSize := int(packet.HeaderSize)
	tagLen := g.srtpGCM.Overhead()

	// Reject truncated packets up front: the slicing below would otherwise
	// panic on input shorter than header + tag.
	if headerSize < 0 || len(ciphertext) < headerSize+tagLen {
		return nil, errors.New("srtp packet is too short for its header and authentication tag")
	}

	// Start with a copy of the header; Open appends the plaintext after it.
	// The capacity is exactly header + plaintext, so Open writes in place.
	dst := make([]byte, headerSize, len(ciphertext)-tagLen)
	copy(dst, ciphertext[:headerSize])

	iv := g.rtpInitializationVector(packet.Header, roc)
	dst, err := g.srtpGCM.Open(dst, iv, ciphertext[headerSize:], ciphertext[:headerSize])
	if err != nil {
		return nil, err
	}
	return dst, nil
}

// aesCmKeyDerivation derives outLen bytes of session keying material from
// masterKey and masterSalt with the AES-CM PRF of RFC 3711 (section 4.3 and
// appendix B.3).
//
// label selects which session value is derived (see the label* constants).
// indexOverKdr is "index DIV kdr"; only 0 (no key refresh) is supported.
// masterSalt may be at most 14 bytes, because the last two bytes of every
// 16-byte PRF input block hold the block counter.
//
// The PRF always works in 16-byte AES blocks, whatever the key length. Earlier
// revisions sized the block by the key length, which produced zero-filled
// output for 32-byte keys.
//
// Adapted from
// https://github.com/pion/srtp/blob/3c34651fa0c6de900bdc91062e7ccb5992409643/key_derivation.go#L8
func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr int, outLen int) ([]byte, error) {
	if indexOverKdr != 0 {
		// 24-bit "index DIV kdr" must be xored to prf input.
		return nil, errors.New("non-zero kdr not supported")
	}
	if outLen < 0 {
		return nil, errors.New("negative key derivation output length")
	}
	if len(masterSalt) > aes.BlockSize-2 {
		return nil, errors.New("master salt too long for AES-CM key derivation")
	}

	// The resulting value is AES encrypted with the master key to get the
	// session value.
	block, err := aes.NewCipher(masterKey)
	if err != nil {
		return nil, err
	}

	// https://tools.ietf.org/html/rfc3711#appendix-B.3
	// The input block for AES-CM is generated by exclusive-oring the master
	// salt with the concatenation of the label with (index DIV kdr). Because
	// index DIV kdr is zero here, only the label byte is XORed in.
	prfIn := make([]byte, aes.BlockSize)
	copy(prfIn, masterSalt)
	prfIn[7] ^= label

	// Keystream blocks are AES_k(x || i) for a 16-bit big-endian counter
	// i = 0, 1, ...; produce whole blocks, then trim to outLen.
	out := make([]byte, (outLen+aes.BlockSize-1)/aes.BlockSize*aes.BlockSize)
	var i uint16
	for n := 0; n < outLen; n += aes.BlockSize {
		binary.BigEndian.PutUint16(prfIn[aes.BlockSize-2:], i)
		block.Encrypt(out[n:n+aes.BlockSize], prfIn)
		i++
	}
	return out[:outLen], nil
}
